// Source path: BlackoutClient / Assets / Best HTTP / Source / SecureProtocol / crypto / tls / RecordStream.cs
#if !BESTHTTP_DISABLE_ALTERNATE_SSL && (!UNITY_WEBGL || UNITY_EDITOR)
#pragma warning disable
using System;
using System.IO;

using BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities;
using BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.IO;

namespace BestHTTP.SecureProtocol.Org.BouncyCastle.Crypto.Tls
{
    /// <summary>An implementation of the TLS 1.0/1.1/1.2 record layer, allowing downgrade to SSLv3.</summary>
    internal sealed class RecordStream
    {
        // Default maximum plaintext fragment size: 2^14 bytes. Applied by Init().
        private const int DEFAULT_PLAINTEXT_LIMIT = (1 << 14);

        // Record header layout: content type at offset 0, protocol version at offset 1,
        // fragment length (read/written as a 16-bit value) at offset 3; 5 bytes in total.
        internal const int TLS_HEADER_SIZE = 5;
        internal const int TLS_HEADER_TYPE_OFFSET = 0;
        internal const int TLS_HEADER_VERSION_OFFSET = 1;
        internal const int TLS_HEADER_LENGTH_OFFSET = 3;

        // Receives every decoded record through ProcessRecord().
        private TlsProtocol mHandler;
        // Stream records are read from.
        private Stream mInput;
        // Stream records are written to.
        private Stream mOutput;
        // The pending compression/cipher are promoted to the write side by SentWriteCipherSpec()
        // and to the read side by ReceivedReadCipherSpec(); FinaliseHandshake() clears them.
        private TlsCompression mPendingCompression = null, mReadCompression = null, mWriteCompression = null;
        private TlsCipher mPendingCipher = null, mReadCipher = null, mWriteCipher = null;
        // Per-direction record sequence numbers; replaced with a fresh counter whenever the
        // corresponding cipher spec changes.
        private SequenceNumber mReadSeqNo = new SequenceNumber(), mWriteSeqNo = new SequenceNumber();
        // Scratch buffer handed to Compress()/Decompress(); drained by GetBufferContents().
        private MemoryStream mBuffer = new MemoryStream();

        // Running handshake hash; created in Init() and replaced by NotifyHelloComplete()
        // and PrepareToFinish().
        private TlsHandshakeHash mHandshakeHash = null;
        // Write-only stream that forwards everything written to it into mHandshakeHash.
        private readonly BaseOutputStream mHandshakeHashUpdater;

        // mReadVersion stays null until the first record is read in restricted mode (or until
        // it is set through the ReadVersion property). While mWriteVersion is null,
        // WriteRecord() sends nothing.
        private ProtocolVersion mReadVersion = null, mWriteVersion = null;
        // true: record versions must equal mReadVersion; false: any {03,XX} version is accepted.
        private bool mRestrictReadVersion = true;

        // Size limits derived from the plaintext limit in SetPlaintextLimit().
        private int mPlaintextLimit, mCompressedLimit, mCiphertextLimit;

        /// <summary>
        /// Creates a record stream that reads records from <paramref name="input"/>, writes them to
        /// <paramref name="output"/> and hands decoded records to <paramref name="handler"/>.
        /// </summary>
        /// <remarks>
        /// Both directions start with one shared <see cref="TlsNullCompression"/> instance. The
        /// ciphers and the handshake hash are only created by <see cref="Init"/>.
        /// </remarks>
        /// <param name="handler">Protocol handler that receives decoded records.</param>
        /// <param name="input">Stream to read records from.</param>
        /// <param name="output">Stream to write records to.</param>
        internal RecordStream(TlsProtocol handler, Stream input, Stream output)
        {
            mHandler = handler;
            mInput = input;
            mOutput = output;

            // One shared null-compression instance serves both directions.
            TlsCompression nullCompression = new TlsNullCompression();
            mReadCompression = nullCompression;
            mWriteCompression = nullCompression;

            mHandshakeHashUpdater = new HandshakeHashUpdateStream(this);
        }

        /// <summary>
        /// Installs the initial connection state: one shared <see cref="TlsNullCipher"/> for both
        /// directions, a new <see cref="DeferredHash"/> as handshake hash, and the default
        /// plaintext limit.
        /// </summary>
        /// <param name="context">Context passed to the null cipher and the handshake hash.</param>
        internal /*virtual*/ void Init(TlsContext context)
        {
            TlsCipher nullCipher = new TlsNullCipher(context);
            mReadCipher = nullCipher;
            mWriteCipher = nullCipher;

            // The field is assigned before Init() is called on the hash.
            mHandshakeHash = new DeferredHash();
            mHandshakeHash.Init(context);

            SetPlaintextLimit(DEFAULT_PLAINTEXT_LIMIT);
        }

        /// <summary>Gets the current maximum plaintext fragment length.</summary>
        /// <returns>The value last stored by <see cref="SetPlaintextLimit"/>.</returns>
        internal /*virtual*/ int GetPlaintextLimit()
        {
            return this.mPlaintextLimit;
        }

        /// <summary>
        /// Sets the maximum plaintext fragment length and derives the other two limits from it:
        /// the compressed limit is 1024 above the plaintext limit and the ciphertext limit is
        /// 1024 above the compressed limit.
        /// </summary>
        /// <param name="plaintextLimit">New maximum plaintext fragment length.</param>
        internal /*virtual*/ void SetPlaintextLimit(int plaintextLimit)
        {
            int compressedLimit = plaintextLimit + 1024;
            int ciphertextLimit = compressedLimit + 1024;

            mPlaintextLimit = plaintextLimit;
            mCompressedLimit = compressedLimit;
            mCiphertextLimit = ciphertextLimit;
        }

        /// <summary>
        /// Gets or sets the protocol version that incoming record headers are compared against
        /// while the read version is restricted. May be null.
        /// </summary>
        internal /*virtual*/ ProtocolVersion ReadVersion
        {
            get
            {
                return this.mReadVersion;
            }
            set
            {
                mReadVersion = value;
            }
        }

        /// <summary>
        /// Sets the protocol version written into outgoing record headers. While it is null,
        /// <see cref="WriteRecord"/> sends nothing.
        /// </summary>
        /// <param name="writeVersion">Version to put into outgoing record headers.</param>
        internal /*virtual*/ void SetWriteVersion(ProtocolVersion writeVersion)
        {
            mWriteVersion = writeVersion;
        }

        /// <summary>
        /// Chooses how the version field of incoming record headers is validated: when enabled it
        /// must equal the current read version; when disabled any version of the form {03,XX}
        /// is accepted.
        /// </summary>
        /// <remarks>
        /// RFC 5246 E.1. "Earlier versions of the TLS specification were not fully clear on what the
        /// record layer version number (TLSPlaintext.version) should contain when sending ClientHello
        /// (i.e., before it is known which version of the protocol will be employed). Thus, TLS servers
        /// compliant with this specification MUST accept any value {03,XX} as the record layer version
        /// number for ClientHello."
        /// </remarks>
        /// <param name="enabled">true to require an exact version match; false to relax the check.</param>
        internal /*virtual*/ void SetRestrictReadVersion(bool enabled)
        {
            mRestrictReadVersion = enabled;
        }

        /// <summary>
        /// Stores the compression and cipher that become active once the change-cipher-spec
        /// messages are sent or received.
        /// </summary>
        /// <param name="tlsCompression">Compression to activate later.</param>
        /// <param name="tlsCipher">Cipher to activate later.</param>
        internal /*virtual*/ void SetPendingConnectionState(TlsCompression tlsCompression, TlsCipher tlsCipher)
        {
            mPendingCipher = tlsCipher;
            mPendingCompression = tlsCompression;
        }

        /// <summary>
        /// Activates the pending compression and cipher for outgoing records and restarts the
        /// write sequence number.
        /// </summary>
        /// <exception cref="TlsFatalAlert">
        /// handshake_failure when no pending compression or no pending cipher has been set.
        /// </exception>
        internal /*virtual*/ void SentWriteCipherSpec()
        {
            bool pendingStateReady = mPendingCompression != null && mPendingCipher != null;
            if (!pendingStateReady)
            {
                throw new TlsFatalAlert(AlertDescription.handshake_failure);
            }

            mWriteCompression = mPendingCompression;
            mWriteCipher = mPendingCipher;
            mWriteSeqNo = new SequenceNumber();
        }

        /// <summary>
        /// Activates the pending compression and cipher for incoming records and restarts the
        /// read sequence number.
        /// </summary>
        /// <exception cref="TlsFatalAlert">
        /// handshake_failure when no pending compression or no pending cipher has been set.
        /// </exception>
        internal /*virtual*/ void ReceivedReadCipherSpec()
        {
            bool pendingStateReady = mPendingCompression != null && mPendingCipher != null;
            if (!pendingStateReady)
            {
                throw new TlsFatalAlert(AlertDescription.handshake_failure);
            }

            mReadCompression = mPendingCompression;
            mReadCipher = mPendingCipher;
            mReadSeqNo = new SequenceNumber();
        }

        /// <summary>
        /// Verifies that the pending compression and cipher are the ones active in both
        /// directions, then clears the pending state.
        /// </summary>
        /// <exception cref="TlsFatalAlert">
        /// handshake_failure when either direction is not using the pending compression or cipher.
        /// </exception>
        internal /*virtual*/ void FinaliseHandshake()
        {
            bool compressionInstalled = mReadCompression == mPendingCompression
                && mWriteCompression == mPendingCompression;
            bool cipherInstalled = mReadCipher == mPendingCipher
                && mWriteCipher == mPendingCipher;

            if (!(compressionInstalled && cipherInstalled))
            {
                throw new TlsFatalAlert(AlertDescription.handshake_failure);
            }

            mPendingCompression = null;
            mPendingCipher = null;
        }

        /// <summary>
        /// Validates the content type, version and length fields of a record header without
        /// reading the record body and without changing the read version.
        /// </summary>
        /// <param name="recordHeader">Buffer holding the record header, starting at index 0.</param>
        /// <exception cref="TlsFatalAlert">
        /// unexpected_message for an unknown content type, illegal_parameter for an unacceptable
        /// version, record_overflow when the length exceeds the ciphertext limit.
        /// </exception>
        internal /*virtual*/ void CheckRecordHeader(byte[] recordHeader)
        {
            /*
             * RFC 5246 6. If a TLS implementation receives an unexpected record type, it MUST send an
             * unexpected_message alert.
             */
            byte contentType = TlsUtilities.ReadUint8(recordHeader, TLS_HEADER_TYPE_OFFSET);
            CheckType(contentType, AlertDescription.unexpected_message);

            if (mRestrictReadVersion)
            {
                ProtocolVersion headerVersion = TlsUtilities.ReadVersion(recordHeader, TLS_HEADER_VERSION_OFFSET);

                // A null read version means nothing has been pinned yet; ReadRecord() pins it,
                // so there is nothing to compare against here.
                if (mReadVersion != null && !headerVersion.Equals(mReadVersion))
                {
                    throw new TlsFatalAlert(AlertDescription.illegal_parameter);
                }
            }
            else
            {
                // Relaxed mode: only the major version byte has to be 0x03.
                int rawVersion = TlsUtilities.ReadVersionRaw(recordHeader, TLS_HEADER_VERSION_OFFSET);
                if ((rawVersion & 0xffffff00) != 0x0300)
                {
                    throw new TlsFatalAlert(AlertDescription.illegal_parameter);
                }
            }

            int fragmentLength = TlsUtilities.ReadUint16(recordHeader, TLS_HEADER_LENGTH_OFFSET);
            CheckLength(fragmentLength, mCiphertextLimit, AlertDescription.record_overflow);
        }

        /// <summary>
        /// Reads one record from the input stream, validates its header, decodes it and passes
        /// the plaintext to the protocol handler.
        /// </summary>
        /// <remarks>
        /// In restricted mode the version of the first record read becomes the read version if
        /// none has been set yet.
        /// </remarks>
        /// <returns>
        /// false when no header could be read (ReadAllOrNothing returned null, presumably because
        /// the input ended before a header started); true after a record has been processed.
        /// </returns>
        /// <exception cref="TlsFatalAlert">
        /// unexpected_message for an unknown content type, illegal_parameter for an unacceptable
        /// version, record_overflow when the length exceeds the ciphertext limit, plus whatever
        /// <see cref="DecodeAndVerify"/> raises.
        /// </exception>
        internal /*virtual*/ bool ReadRecord()
        {
            byte[] header = TlsUtilities.ReadAllOrNothing(TLS_HEADER_SIZE, mInput);
            if (header == null)
            {
                return false;
            }

            /*
             * RFC 5246 6. If a TLS implementation receives an unexpected record type, it MUST send an
             * unexpected_message alert.
             */
            byte contentType = TlsUtilities.ReadUint8(header, TLS_HEADER_TYPE_OFFSET);
            CheckType(contentType, AlertDescription.unexpected_message);

            if (mRestrictReadVersion)
            {
                ProtocolVersion headerVersion = TlsUtilities.ReadVersion(header, TLS_HEADER_VERSION_OFFSET);
                if (mReadVersion == null)
                {
                    // First record seen: pin the read version to whatever the peer sent.
                    mReadVersion = headerVersion;
                }
                else if (!headerVersion.Equals(mReadVersion))
                {
                    throw new TlsFatalAlert(AlertDescription.illegal_parameter);
                }
            }
            else
            {
                // Relaxed mode: only the major version byte has to be 0x03.
                int rawVersion = TlsUtilities.ReadVersionRaw(header, TLS_HEADER_VERSION_OFFSET);
                if ((rawVersion & 0xffffff00) != 0x0300)
                {
                    throw new TlsFatalAlert(AlertDescription.illegal_parameter);
                }
            }

            int fragmentLength = TlsUtilities.ReadUint16(header, TLS_HEADER_LENGTH_OFFSET);
            CheckLength(fragmentLength, mCiphertextLimit, AlertDescription.record_overflow);

            byte[] plaintext = DecodeAndVerify(contentType, mInput, fragmentLength);
            mHandler.ProcessRecord(contentType, plaintext, 0, plaintext.Length);
            return true;
        }

        /// <summary>
        /// Reads a record body of <paramref name="len"/> bytes from <paramref name="input"/>,
        /// passes it through the read cipher and the read compression, and enforces the size limits.
        /// </summary>
        /// <param name="type">Content type taken from the record header.</param>
        /// <param name="input">Stream positioned at the start of the record body.</param>
        /// <param name="len">Length of the record body in bytes.</param>
        /// <returns>The decoded (and, if applicable, decompressed) fragment.</returns>
        /// <exception cref="TlsFatalAlert">
        /// unexpected_message when the read sequence number is exhausted, record_overflow when the
        /// decoded data exceeds the compressed limit, decompression_failure when the decompressed
        /// data exceeds the plaintext limit, illegal_parameter for an empty fragment of any type
        /// other than application_data.
        /// </exception>
        internal /*virtual*/ byte[] DecodeAndVerify(byte type, Stream input, int len)
        {
            byte[] ciphertext = TlsUtilities.ReadFully(len, input);

            long sequenceNumber = mReadSeqNo.NextValue(AlertDescription.unexpected_message);
            byte[] fragment = mReadCipher.DecodeCiphertext(sequenceNumber, type, ciphertext, 0, ciphertext.Length);

            CheckLength(fragment.Length, mCompressedLimit, AlertDescription.record_overflow);

            /*
             * TODO 5246 6.2.2. Implementation note: Decompression functions are responsible for
             * ensuring that messages cannot cause internal buffer overflows.
             */
            Stream decompressor = mReadCompression.Decompress(mBuffer);
            bool compressionActive = decompressor != mBuffer;
            if (compressionActive)
            {
                // Decompressed output lands in mBuffer; collect it and reset the buffer.
                decompressor.Write(fragment, 0, fragment.Length);
                decompressor.Flush();
                fragment = GetBufferContents();
            }

            /*
             * RFC 5246 6.2.2. If the decompression function encounters a TLSCompressed.fragment that
             * would decompress to a length in excess of 2^14 bytes, it should report a fatal
             * decompression failure error.
             */
            CheckLength(fragment.Length, mPlaintextLimit, AlertDescription.decompression_failure);

            /*
             * RFC 5246 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert,
             * or ChangeCipherSpec content types.
             */
            if (type != ContentType.application_data && fragment.Length < 1)
            {
                throw new TlsFatalAlert(AlertDescription.illegal_parameter);
            }

            return fragment;
        }

        /// <summary>
        /// Compresses and protects one fragment, frames it with a record header and writes it to
        /// the output stream, which is then flushed.
        /// </summary>
        /// <remarks>
        /// Nothing is sent while no write version has been set. The record is assembled in a buffer
        /// taken from the buffer pool; that buffer is returned to the pool even when assembling or
        /// writing the record throws.
        /// </remarks>
        /// <param name="type">Content type of the record.</param>
        /// <param name="plaintext">Buffer containing the fragment to send.</param>
        /// <param name="plaintextOffset">Index of the first fragment byte in <paramref name="plaintext"/>.</param>
        /// <param name="plaintextLength">Number of fragment bytes.</param>
        /// <exception cref="TlsFatalAlert">
        /// internal_error when the content type is unknown, the fragment exceeds the plaintext
        /// limit, an empty fragment is sent for a type other than application_data, the write
        /// sequence number is exhausted, compression grows the data by more than 1024 bytes, or
        /// the ciphertext exceeds the ciphertext limit.
        /// </exception>
        internal /*virtual*/ void WriteRecord(byte type, byte[] plaintext, int plaintextOffset, int plaintextLength)
        {
            // Never send anything until a valid ClientHello has been received
            if (mWriteVersion == null)
                return;

            /*
             * RFC 5246 6. Implementations MUST NOT send record types not defined in this document
             * unless negotiated by some extension.
             */
            CheckType(type, AlertDescription.internal_error);

            /*
             * RFC 5246 6.2.1 The length should not exceed 2^14.
             */
            CheckLength(plaintextLength, mPlaintextLimit, AlertDescription.internal_error);

            /*
             * RFC 5246 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert,
             * or ChangeCipherSpec content types.
             */
            if (plaintextLength < 1 && type != ContentType.application_data)
                throw new TlsFatalAlert(AlertDescription.internal_error);

            Stream cOut = mWriteCompression.Compress(mBuffer);

            long seqNo = mWriteSeqNo.NextValue(AlertDescription.internal_error);

            byte[] ciphertext;
            if (cOut == mBuffer)
            {
                // No compression in effect: protect the caller's bytes directly.
                ciphertext = mWriteCipher.EncodePlaintext(seqNo, type, plaintext, plaintextOffset, plaintextLength);
            }
            else
            {
                // Compressed output lands in mBuffer; collect it and reset the buffer.
                cOut.Write(plaintext, plaintextOffset, plaintextLength);
                cOut.Flush();
                byte[] compressed = GetBufferContents();

                /*
                 * RFC 5246 6.2.2. Compression must be lossless and may not increase the content length
                 * by more than 1024 bytes.
                 */
                CheckLength(compressed.Length, plaintextLength + 1024, AlertDescription.internal_error);

                ciphertext = mWriteCipher.EncodePlaintext(seqNo, type, compressed, 0, compressed.Length);
            }

            /*
             * RFC 5246 6.2.3. The length may not exceed 2^14 + 2048.
             */
            CheckLength(ciphertext.Length, mCiphertextLimit, AlertDescription.internal_error);

            int recordLength = ciphertext.Length + TLS_HEADER_SIZE;
            byte[] record = BestHTTP.PlatformSupport.Memory.BufferPool.Get(recordLength, true);
            try
            {
                TlsUtilities.WriteUint8(type, record, TLS_HEADER_TYPE_OFFSET);
                TlsUtilities.WriteVersion(mWriteVersion, record, TLS_HEADER_VERSION_OFFSET);
                TlsUtilities.WriteUint16(ciphertext.Length, record, TLS_HEADER_LENGTH_OFFSET);
                Array.Copy(ciphertext, 0, record, TLS_HEADER_SIZE, ciphertext.Length);

                // Only the first recordLength bytes of the pooled buffer belong to this record.
                mOutput.Write(record, 0, recordLength);
            }
            finally
            {
                // Return the pooled buffer even if the write fails, so it is not lost to the pool.
                BestHTTP.PlatformSupport.Memory.BufferPool.Release(record);
            }
            mOutput.Flush();
        }

        /// <summary>
        /// Replaces the handshake hash with the instance returned by its
        /// <c>NotifyPrfDetermined()</c> call.
        /// </summary>
        internal /*virtual*/ void NotifyHelloComplete()
        {
            TlsHandshakeHash updatedHash = mHandshakeHash.NotifyPrfDetermined();
            mHandshakeHash = updatedHash;
        }

        /// <summary>Gets the current handshake hash (null until <see cref="Init"/> has run).</summary>
        internal /*virtual*/ TlsHandshakeHash HandshakeHash
        {
            get
            {
                return this.mHandshakeHash;
            }
        }

        /// <summary>
        /// Gets a stream whose writes are fed into the current handshake hash.
        /// </summary>
        internal /*virtual*/ Stream HandshakeHashUpdater
        {
            get
            {
                return this.mHandshakeHashUpdater;
            }
        }

        /// <summary>
        /// Returns the handshake hash used so far and replaces it with the instance returned by
        /// its <c>StopTracking()</c> call.
        /// </summary>
        /// <returns>The handshake hash that was current before the call.</returns>
        internal /*virtual*/ TlsHandshakeHash PrepareToFinish()
        {
            TlsHandshakeHash previousHash = mHandshakeHash;
            mHandshakeHash = previousHash.StopTracking();
            return previousHash;
        }

        /// <summary>
        /// Disposes the input stream and then the output stream, ignoring any
        /// <see cref="IOException"/> raised by either. Other exception types propagate.
        /// </summary>
        internal /*virtual*/ void SafeClose()
        {
            DisposeIgnoringIOException(mInput);
            DisposeIgnoringIOException(mOutput);
        }

        /// <summary>Disposes <paramref name="stream"/>, swallowing only <see cref="IOException"/>.</summary>
        private static void DisposeIgnoringIOException(Stream stream)
        {
            try
            {
                BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.Dispose(stream);
            }
            catch (IOException)
            {
                // Best-effort close: an I/O failure while closing is deliberately ignored.
            }
        }

        /// <summary>Flushes the output stream.</summary>
        internal /*virtual*/ void Flush()
        {
            this.mOutput.Flush();
        }

        /// <summary>
        /// Copies everything currently held in the scratch buffer and then empties the buffer.
        /// </summary>
        /// <returns>A new array with the bytes that were in the buffer.</returns>
        private byte[] GetBufferContents()
        {
            byte[] snapshot = mBuffer.ToArray();

            // Truncate so the next compression/decompression pass starts from an empty buffer.
            mBuffer.SetLength(0);

            return snapshot;
        }

        /// <summary>
        /// Accepts the content types application_data, alert, change_cipher_spec and handshake;
        /// rejects everything else.
        /// </summary>
        /// <param name="type">Content type to check.</param>
        /// <param name="alertDescription">Alert description used when the type is rejected.</param>
        /// <exception cref="TlsFatalAlert">When <paramref name="type"/> is not an accepted content type.</exception>
        private static void CheckType(byte type, byte alertDescription)
        {
            // ContentType.heartbeat is not part of the accepted set.
            bool isAcceptedType =
                type == ContentType.application_data
                || type == ContentType.alert
                || type == ContentType.change_cipher_spec
                || type == ContentType.handshake;

            if (!isAcceptedType)
            {
                throw new TlsFatalAlert(alertDescription);
            }
        }

        /// <summary>Rejects a length that is strictly greater than the given limit.</summary>
        /// <param name="length">Length to check.</param>
        /// <param name="limit">Largest accepted length (inclusive).</param>
        /// <param name="alertDescription">Alert description used when the length is rejected.</param>
        /// <exception cref="TlsFatalAlert">When <paramref name="length"/> exceeds <paramref name="limit"/>.</exception>
        private static void CheckLength(int length, int limit, byte alertDescription)
        {
            if (length <= limit)
            {
                return;
            }

            throw new TlsFatalAlert(alertDescription);
        }

        /// <summary>
        /// Output stream that forwards every written block to the handshake hash currently held
        /// by the owning <see cref="RecordStream"/>. The hash is looked up on each write, so a
        /// replaced hash instance is picked up automatically.
        /// </summary>
        private class HandshakeHashUpdateStream
            : BaseOutputStream
        {
            private readonly RecordStream mRecordStream;

            public HandshakeHashUpdateStream(RecordStream owner)
            {
                mRecordStream = owner;
            }

            public override void Write(byte[] buf, int off, int len)
            {
                TlsHandshakeHash currentHash = mRecordStream.mHandshakeHash;
                currentHash.BlockUpdate(buf, off, len);
            }
        }

        /// <summary>
        /// Record sequence counter starting at 0. Once the counter wraps around to 0 again it is
        /// marked exhausted and every further request fails.
        /// </summary>
        private class SequenceNumber
        {
            private long mNext = 0L;
            private bool mExhausted = false;

            /// <summary>Returns the current sequence number and advances the counter.</summary>
            /// <param name="alertDescription">Alert description used when the counter is exhausted.</param>
            /// <returns>The sequence number to use for the next record.</returns>
            /// <exception cref="TlsFatalAlert">When the counter has already wrapped around.</exception>
            internal long NextValue(byte alertDescription)
            {
                if (mExhausted)
                    throw new TlsFatalAlert(alertDescription);

                long current = mNext;

                // Advancing back to 0 means every value has been handed out.
                ++mNext;
                mExhausted = (mNext == 0);

                return current;
            }
        }
    }
}
#pragma warning restore
#endif